/**
 * 
 */
package edu.berkeley.nlp.discPCFG;

import java.util.List;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveLexicalRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveUnaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalFullyConnectedAdaptiveLexicon;
import edu.berkeley.nlp.PCFGLA.HierarchicalGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalUnaryRule;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.StateSetWithFeatures;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;

/**
 * Linearizer for a {@link HierarchicalAdaptiveGrammar} paired with a
 * {@link HierarchicalFullyConnectedAdaptiveLexicon}.
 *
 * <p>Every rule is given an {@code identifier}, which is the offset of its
 * first parameter in a combined weight vector. Grammar rules are numbered from
 * 0; lexical rules are numbered from {@code nGrammarWeights} onwards, so
 * grammar weights precede lexicon weights in that vector.
 *
 * @author petrov
 */
public class HiearchicalAdaptiveLinearizer extends HierarchicalLinearizer {
	private static final long serialVersionUID = 1L;

	HierarchicalAdaptiveGrammar grammar;
	HierarchicalFullyConnectedAdaptiveLexicon lexicon;

	/**
	 * Creates a linearizer over the given grammar and lexicon.
	 *
	 * @param grammar grammar to linearize; must be a {@link HierarchicalAdaptiveGrammar}
	 * @param lexicon lexicon to linearize; must be a
	 *                {@link HierarchicalFullyConnectedAdaptiveLexicon}
	 * @param sp      span predictor stored on this linearizer
	 * @param fLevel  final hierarchy level, used when scores are explicitly computed
	 * @throws ClassCastException if {@code grammar} or {@code lexicon} is not of the
	 *                            required adaptive type
	 */
	public HiearchicalAdaptiveLinearizer(Grammar grammar, SimpleLexicon lexicon, SpanPredictor sp, int fLevel) {
		this.grammar = (HierarchicalAdaptiveGrammar)grammar;
		lexicon.explicitlyComputeScores(fLevel);
		resetUnaryClosures(grammar);

		this.lexicon = (HierarchicalFullyConnectedAdaptiveLexicon)lexicon;
		this.spanPredictor = sp;
		this.finalLevel = fLevel;
		this.nSubstates = (int)ArrayUtil.max(grammar.numSubStates);
		init();
		computeMappings();
	}

	/**
	 * Points the grammar's closed (sum and Viterbi) unary rule tables at its plain
	 * unary rule tables, then clears the unary intermediates and rebuilds the CR
	 * arrays. Private and static so it is safe to call from the constructor.
	 */
	private static void resetUnaryClosures(Grammar g) {
		g.closedSumRulesWithParent = g.closedViterbiRulesWithParent = g.unaryRulesWithParent;
		g.closedSumRulesWithChild = g.closedViterbiRulesWithChild = g.unaryRulesWithC;
		g.clearUnaryIntermediates();
		g.makeCRArrays();
	}

	/** @return the lexicon this linearizer operates on */
	public SimpleLexicon getLexicon() {
		return lexicon;
	}

	/** @return the grammar this linearizer operates on */
	public Grammar getGrammar() {
		return grammar;
	}

	/**
	 * Flattens the final-level values of all lexical rules into one array.
	 *
	 * <p>Rules are visited tag by tag and, within a tag, word by word; the values
	 * returned by each rule's {@code getFinalLevel()} are appended in that order.
	 *
	 * @param update if {@code true}, first recompute {@code nLexiconWeights} and
	 *               reassign every lexical rule's {@code identifier} to
	 *               {@code nGrammarWeights} plus the number of lexicon weights that
	 *               precede it (so the result depends on the current
	 *               {@code nGrammarWeights})
	 * @return a new array of length {@code nLexiconWeights}; a message is printed
	 *         to standard output if fewer values than that were written
	 */
	public double[] getLinearizedLexicon(boolean update) {
		if (update) {
			nLexiconWeights = 0;
			for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
				for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
					rule.identifier = nLexiconWeights + nGrammarWeights;
					nLexiconWeights += rule.getFinalLevel().size();
				}
			}
		}

		double[] logProbs = new double[nLexiconWeights];
		int index = 0;
		for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
			for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
				List<Double> vals = rule.getFinalLevel();
				for (Double val : vals) {
					logProbs[index++] = val;
				}
			}
		}
		if (index != logProbs.length) {
			System.out.println("unequal length in lexicon");
		}
		return logProbs;
	}

	/**
	 * Pushes a weight vector back into every lexical rule and recomputes its
	 * explicit scores.
	 *
	 * @param logProbs          weight vector handed to each rule's {@code updateScores}
	 * @param usingOnlyLastLevel forwarded to each rule's {@code explicitlyComputeScores}
	 */
	public void delinearizeLexicon(double[] logProbs, boolean usingOnlyLastLevel) {
		applyLexiconWeights(logProbs, usingOnlyLastLevel);
	}

	/**
	 * Same as {@link #delinearizeLexicon(double[], boolean)} with
	 * {@code usingOnlyLastLevel} set to {@code false}.
	 *
	 * @param logProbs weight vector handed to each rule's {@code updateScores}
	 */
	public void delinearizeLexicon(double[] logProbs) {
		applyLexiconWeights(logProbs, false);
	}

	/** Shared body of the two {@code delinearizeLexicon} overloads. */
	private void applyLexiconWeights(double[] logProbs, boolean usingOnlyLastLevel) {
		for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
			for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
				rule.updateScores(logProbs);
				rule.explicitlyComputeScores(finalLevel, usingOnlyLastLevel);
			}
		}
	}

	/**
	 * Adds (gold) or subtracts (non-gold) the per-substate weights of a tagged
	 * word to the counts of the lexical rules it triggers, then zeroes the first
	 * {@code nSubstates} entries of {@code weights}.
	 *
	 * <p>For a plain {@link StateSet} the signature rule (when
	 * {@code sigIndex != -1} and known for this tag) and the word rule (when known
	 * for this tag) are both updated. For a {@link StateSetWithFeatures} every
	 * non-negative feature known for this tag is updated.
	 *
	 * @param counts   accumulator indexed by rule identifier plus mapped substate
	 * @param stateSet the word position being counted
	 * @param tag      tag index into the lexicon's rule and indexer tables
	 * @param weights  per-substate weights; its first {@code nSubstates} entries
	 *                 are reset to 0 on return
	 * @param isGold   {@code true} to add the weights, {@code false} to subtract them
	 */
	public void increment(double[] counts, StateSet stateSet, int tag, double[] weights, boolean isGold) {
		if (stateSet instanceof StateSetWithFeatures) {
			StateSetWithFeatures stateSetF = (StateSetWithFeatures) stateSet;
			for (int f : stateSetF.features) {
				if (f < 0) {
					continue;
				}
				int tagF = lexicon.tagWordIndexer[tag].indexOf(f);
				if (tagF < 0) {
					continue;
				}
				addLexicalCounts(counts, lexicon.rules[tag][tagF], weights, isGold);
			}
		} else {
			int globalSigIndex = stateSet.sigIndex;
			if (globalSigIndex != -1) {
				int tagSpecificSigIndex = lexicon.tagWordIndexer[tag].indexOf(globalSigIndex);
				if (tagSpecificSigIndex >= 0) {
					addLexicalCounts(counts, lexicon.rules[tag][tagSpecificSigIndex], weights, isGold);
				}
			}
			int globalWordIndex = stateSet.wordIndex;
			int tagSpecificWordIndex = lexicon.tagWordIndexer[tag].indexOf(globalWordIndex);
			if (tagSpecificWordIndex >= 0) {
				addLexicalCounts(counts, lexicon.rules[tag][tagSpecificWordIndex], weights, isGold);
			}
		}

		// Every path consumes the weights, whether or not any rule matched.
		for (int i = 0; i < nSubstates; i++) {
			weights[i] = 0;
		}
	}

	/**
	 * Adds or subtracts {@code weights[i]} at
	 * {@code counts[rule.identifier + rule.mapping[i]]} for each of the first
	 * {@code nSubstates} substates. Does not modify {@code weights}.
	 */
	private void addLexicalCounts(double[] counts, HierarchicalAdaptiveLexicalRule rule, double[] weights, boolean isGold) {
		int startIndexWord = rule.identifier;
		short[] mapping = rule.mapping;
		for (int i = 0; i < nSubstates; i++) {
			if (isGold) {
				counts[startIndexWord + mapping[i]] += weights[i];
			} else {
				counts[startIndexWord + mapping[i]] -= weights[i];
			}
		}
	}

	/**
	 * Moves the positive entries of {@code weights} into the count slots of a
	 * binary rule, adding them for gold and subtracting them otherwise.
	 *
	 * @param counts  accumulator indexed by rule identifier plus parameter index
	 * @param rule    must be a {@link HierarchicalAdaptiveBinaryRule}
	 * @param weights per-parameter weights; each positive entry among the first
	 *                {@code nParam} is reset to 0 once consumed
	 * @param isGold  {@code true} to add, {@code false} to subtract
	 */
	public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) {
		HierarchicalAdaptiveBinaryRule hr = (HierarchicalAdaptiveBinaryRule)rule;
		transferPositiveWeights(counts, hr.identifier, hr.nParam, weights, isGold);
	}

	/**
	 * Moves the positive entries of {@code weights} into the count slots of a
	 * unary rule, adding them for gold and subtracting them otherwise.
	 *
	 * @param counts  accumulator indexed by rule identifier plus parameter index
	 * @param rule    must be a {@link HierarchicalAdaptiveUnaryRule}
	 * @param weights per-parameter weights; each positive entry among the first
	 *                {@code nParam} is reset to 0 once consumed
	 * @param isGold  {@code true} to add, {@code false} to subtract
	 */
	public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) {
		HierarchicalAdaptiveUnaryRule hr = (HierarchicalAdaptiveUnaryRule)rule;
		transferPositiveWeights(counts, hr.identifier, hr.nParam, weights, isGold);
	}

	/**
	 * For each of the first {@code nParam} weights that is strictly positive,
	 * applies it to {@code counts[thisStartIndex + i]} with the sign chosen by
	 * {@code isGold} and zeroes the weight. Non-positive weights are left alone.
	 */
	private static void transferPositiveWeights(double[] counts, int thisStartIndex, int nParam, double[] weights, boolean isGold) {
		for (int curInd = 0; curInd < nParam; curInd++) {
			double val = weights[curInd];
			if (val > 0) {
				weights[curInd] = 0;
				if (isGold) {
					counts[thisStartIndex + curInd] += val;
				} else {
					counts[thisStartIndex + curInd] -= val;
				}
			}
		}
	}

	/**
	 * Pushes a weight vector back into every binary and unary rule of the
	 * grammar, recomputes the grammar's explicit scores at {@code finalLevel},
	 * and resets the unary closure tables.
	 *
	 * @param probs weight vector handed to each rule's {@code updateScores}
	 */
	public void delinearizeGrammar(double[] probs) {
		for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
			HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule)bRule;
			hRule.updateScores(probs);
		}
		for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
			HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule)uRule;
			hRule.updateScores(probs);
		}

		grammar.explicitlyComputeScores(finalLevel);
		resetUnaryClosures(grammar);
	}

	/**
	 * Flattens the final-level values of all grammar rules into one array, each
	 * rule writing at the offset given by its {@code identifier}.
	 *
	 * <p>Unary rules whose child state equals their parent state are skipped when
	 * copying values: their slots are still reserved but stay at 0.
	 *
	 * @param update if {@code true}, first recompute {@code nGrammarWeights} and
	 *               reassign identifiers, binary rules first and unary rules after
	 *               them; a message is printed to standard output for any binary
	 *               rule whose parent state is not flagged in
	 *               {@code grammar.isGrammarTag}
	 * @return a new array of length {@code nGrammarWeights}
	 */
	public double[] getLinearizedGrammar(boolean update) {
		if (update) {
			nGrammarWeights = 0;
			for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
				HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule)bRule;
				if (!grammar.isGrammarTag[bRule.parentState]) {
					System.out.println("Incorrect grammar tag");
				}
				bRule.identifier = nGrammarWeights;
				nGrammarWeights += hRule.nParam;
			}
			for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
				HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule)uRule;
				uRule.identifier = nGrammarWeights;
				nGrammarWeights += hRule.nParam;
			}
		}

		double[] logProbs = new double[nGrammarWeights];

		for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
			HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule)bRule;
			int ind = hRule.identifier;
			List<Double> vals = hRule.getFinalLevel();
			for (Double val : vals) {
				logProbs[ind++] = val;
			}
		}

		for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
			if (uRule.childState == uRule.parentState) {
				continue;
			}
			HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule)uRule;
			int ind = hRule.identifier;
			List<Double> vals = hRule.getFinalLevel();
			for (Double val : vals) {
				logProbs[ind++] = val;
			}
		}
		return logProbs;
	}

	/**
	 * Applies a combined weight vector to the lexicon only, using
	 * {@code usingOnlyLastLevel = true}.
	 *
	 * @param logWeights combined weight vector; when any weights are registered it
	 *                   must hold at least {@code nGrammarWeights + nLexiconWeights}
	 *                   entries
	 * @throws ArrayIndexOutOfBoundsException if {@code logWeights} is shorter than
	 *                                        that
	 */
	public void delinearizeLexiconWeights(double[] logWeights) {
		// Fail before touching the lexicon if the vector cannot cover all weights.
		int expected = nGrammarWeights + nLexiconWeights;
		if (expected > 0 && logWeights.length < expected) {
			throw new ArrayIndexOutOfBoundsException(
					"weight vector has " + logWeights.length + " entries, expected at least " + expected);
		}
		delinearizeLexicon(logWeights, true);
	}

}
